from pathlib import Path

from DataGeneration.DataGeneration import GenerateData
import numpy as np
import matplotlib.pyplot as plt

N = 5000  # number of samples requested from GenerateData
sigma = 1  # forwarded to GenerateData as `sigma` (presumably the outcome noise scale -- confirm)
myseed = 0  # forwarded to GenerateData as `seed`


def collecting_rule(X):
    """Sample one binary action per row of ``X``.

    Every action is an independent Bernoulli(0.5) draw, independent of the
    covariate values; ``X`` is only used for its number of rows.  Note that
    this returns sampled actions in {0, 1}, not probabilities.

    NOTE(review): the draws come from NumPy's global RNG, so they are only
    reproducible if GenerateData seeds that RNG with ``seed`` -- confirm.
    """
    return np.random.binomial(1, 0.5, size=X.shape[0])


def eyx1(X, A):
    """Ground-truth conditional mean of the outcome given covariates and action.

    Computes ``X1 + X2 + 4 * A * |X1| * |X2| * (1{X1 > 0 and X2 > 0} - 0.5)``,
    so the effect of A=1 versus A=0 is ``+2|X1 X2|`` in the open first
    quadrant and ``-2|X1 X2|`` everywhere else.

    Args:
        X: array whose first two columns are X1 and X2 (indexed as ``X[:, 0]``
            and ``X[:, 1]``).
        A: action, either a scalar or an array broadcastable against the rows
            of ``X``.

    Returns:
        One mean value per row of ``X``.
    """
    x1 = X[:, 0]
    x2 = X[:, 1]
    # Boolean mask minus 0.5 gives +0.5 inside the first quadrant, -0.5 outside.
    in_first_quadrant = (x1 > 0) & (x2 > 0)
    return x1 + x2 + 4 * A * np.abs(x1) * np.abs(x2) * (in_first_quadrant - 0.5)
# Build the simulator from the mean function and the action-sampling rule
# above, then draw one dataset from it.
DG = GenerateData(
    Ydist='normal',
    eyx=eyx1,
    sigma=sigma,
    collecting_rule=collecting_rule,
    k=2,
    N=N,
    seed=myseed,
)
# NOTE(review): by name these look like covariates, actions, noiseless mean
# and observed outcome -- confirm against GenerateData.GenerateData().
X, A, Ey, Y = DG.GenerateData()

# Evaluate the true treatment effect eyx1(x, 1) - eyx1(x, 0) on a regular
# grid_size x grid_size grid over [-1, 1]^2.  The grid is built once and
# reused for both the evaluation points and the contour axes.
grid_size = 100
grid = np.linspace(-1, 1, grid_size)
# Default 'xy' indexing: grid_x1[i, j] == grid[j] and grid_x2[i, j] == grid[i],
# i.e. rows follow X2 and columns follow X1 -- the Z layout contourf expects.
grid_x1, grid_x2 = np.meshgrid(grid, grid)
X_evaluation = np.column_stack([grid_x1.ravel(), grid_x2.ravel()])
trt_effect = eyx1(X_evaluation, 1) - eyx1(X_evaluation, 0)
trt_effect_grid = trt_effect.reshape(grid_x1.shape)

fig, ax = plt.subplots(figsize=(8, 5))
contour = ax.contourf(grid, grid, trt_effect_grid, levels=20, cmap='coolwarm')
contour_lines = ax.contour(grid, grid, trt_effect_grid,
                           levels=20, colors='black', linewidths=0.5)
ax.clabel(contour_lines, inline=True, fontsize=8, fmt='%.2f')

# Two marked reference points, each with a text label placed next to it.
label_box = dict(facecolor='white', alpha=0.1, edgecolor='none',
                 boxstyle='round,pad=0.2')
ax.text(-0.35, 0, 'A', fontsize=12, color='black', zorder=6, bbox=label_box)
ax.text(0.64, 0.4, 'B', fontsize=12, color='black', zorder=6, bbox=label_box)
ax.scatter([0.6, -0.4], [0.5, 0.1], c='black', s=50, zorder=10)

ax.set_xlabel('$X_1$')
ax.set_ylabel('$X_2$')
# ax.set_title("Ground truth CATE")
fig.colorbar(contour)
fig.tight_layout()
fig.show()

# Create the output directory if needed; savefig does not create it and would
# otherwise raise FileNotFoundError on a fresh checkout.
output_path = Path("Figure2/Ground_truth_CATE.pdf")
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path)



# Scatter the simulated covariates, coloured by the true treatment effect at
# each sample.  The effect is evaluated at X itself: a grid-based effect array
# has one entry per grid point, not one per sample, and scatter requires `c`
# to match the number of points.
sample_trt_effect = eyx1(X, 1) - eyx1(X, 0)

plt.figure(figsize=(10, 8))
points = plt.scatter(X[:, 0], X[:, 1], c=sample_trt_effect, cmap='viridis', s=5)
plt.colorbar(points, label='Treatment effect')

# Labels are set after the new figure is created so they land on this
# scatter plot rather than on whichever axes happened to be current.
plt.xlabel('Feature X1')
plt.ylabel('Feature X2')
plt.title('Treatment Effect across Feature Space')

# Equal aspect ratio so X1 and X2 are drawn to the same scale.
plt.gca().set_aspect('equal', adjustable='box')